{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural structured learning with TensorFlow 2.0\n",
    "\n",
    "In this example, we'll demonstrate how to use [Neural Structured Learning](https://www.tensorflow.org/neural_structured_learning) (NSL) to augment training on unstructured data with structured signals. We'll train the NSL model on the [CORA](https://relational.fit.cvut.cz/dataset/CORA) dataset, which contains scientific publications separated into 7 categories. Each publication is represented as a multi-hot encoded vector, which indicates all the words present in the publication text. This is the unstructured portion of the dataset. Additionally, the dataset contains a graph of citations between the publications, which is the structured portion. In this example, we'll use the multi-hot encoded publications as input to a classification network, which will assign each publication to its corresponding class. We'll also augment the training with the structured graph of citations. <br />\n",
    "_For more details, check out the book._"
   ]
  },
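  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the multi-hot representation concrete, here is a minimal sketch that uses a hypothetical 6-word vocabulary (the real CORA vocabulary has 1433 words, and the dataset files already contain the encoded vectors). The names `toy_vocabulary` and `multi_hot_encode` are illustrative assumptions and are not used in the rest of the example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy vocabulary, for illustration only.\n",
    "toy_vocabulary = ['graph', 'neural', 'learning', 'genetic', 'bayesian', 'kernel']\n",
    "\n",
    "\n",
    "def multi_hot_encode(words, vocabulary):\n",
    "    \"\"\"Returns a 0/1 list with a 1 for every vocabulary word found in words.\"\"\"\n",
    "    present = set(words)\n",
    "    return [1 if token in present else 0 for token in vocabulary]\n",
    "\n",
    "\n",
    "# 'neural' appears twice, but it is still encoded as a single 1.\n",
    "toy_publication = ['neural', 'graph', 'neural', 'learning']\n",
    "print(multi_hot_encode(toy_publication, toy_vocabulary))\n",
    "# prints [1, 1, 1, 0, 0, 0]"
   ]
  },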
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with the imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import neural_structured_learning as nsl\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's define the parameters of the training and the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cora dataset path\n",
    "TRAIN_DATA_PATH = 'data/train_merged_examples.tfr'\n",
    "TEST_DATA_PATH = 'data/test_examples.tfr'\n",
    "# Constants used to identify neighbor features in the input.\n",
    "NBR_FEATURE_PREFIX = 'NL_nbr_'\n",
    "NBR_WEIGHT_SUFFIX = '_weight'\n",
    "# Dataset parameters\n",
    "NUM_CLASSES = 7\n",
    "MAX_SEQ_LENGTH = 1433\n",
    "# Number of neighbors to consider in the composite loss function\n",
    "NUM_NEIGHBORS = 1\n",
    "# Training parameters\n",
    "BATCH_SIZE = 128"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll continue with the pre-processing of the dataset, which is implemented with the `make_dataset` (on dataset level) and `parse_example` (on sample level) functions. The final result is a batched `tf.data.Dataset` of feature/label pairs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_example(example_proto: tf.train.Example) -> tuple:\n",
    "    \"\"\"\n",
    "    Extracts relevant fields from the example_proto\n",
    "    :param example_proto: a serialized `tf.train.Example`\n",
    "    :return: A (features, label) pair, where features is a dictionary\n",
    "        containing the relevant features of the example and label is\n",
    "        its class index\n",
    "    \"\"\"\n",
    "    # The 'words' feature is a multi-hot, bag-of-words representation of the\n",
    "    # original raw text. A default value is required for examples that don't\n",
    "    # have the feature.\n",
    "    feature_spec = {\n",
    "        'words':\n",
    "            tf.io.FixedLenFeature(shape=[MAX_SEQ_LENGTH],\n",
    "                                  dtype=tf.int64,\n",
    "                                  default_value=tf.constant(\n",
    "                                      value=0,\n",
    "                                      dtype=tf.int64,\n",
    "                                      shape=[MAX_SEQ_LENGTH])),\n",
    "        'label':\n",
    "            tf.io.FixedLenFeature((), tf.int64, default_value=-1),\n",
    "    }\n",
    "    # We also extract corresponding neighbor features in a similar manner to\n",
    "    # the features above.\n",
    "    for i in range(NUM_NEIGHBORS):\n",
    "        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n",
    "        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)\n",
    "        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(\n",
    "            shape=[MAX_SEQ_LENGTH],\n",
    "            dtype=tf.int64,\n",
    "            default_value=tf.constant(\n",
    "                value=0, dtype=tf.int64, shape=[MAX_SEQ_LENGTH]))\n",
    "\n",
    "        # We assign a default value of 0.0 for the neighbor weight so that\n",
    "        # graph regularization is done on samples based on their exact number\n",
    "        # of neighbors. In other words, non-existent neighbors are discounted.\n",
    "        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(\n",
    "            shape=[1], dtype=tf.float32, default_value=tf.constant([0.0]))\n",
    "\n",
    "    features = tf.io.parse_single_example(example_proto, feature_spec)\n",
    "\n",
    "    labels = features.pop('label')\n",
    "    return features, labels\n",
    "\n",
    "\n",
    "def make_dataset(file_path: str, training=False) -> tf.data.Dataset:\n",
    "    \"\"\"\n",
    "    Creates a batched dataset from a file of serialized examples\n",
    "    :param file_path: name of the file in the `.tfrecord` format containing\n",
    "        `tf.train.Example` objects.\n",
    "    :param training: whether the dataset is for training (enables shuffling)\n",
    "    :return: tf.data.Dataset of batched (features, label) pairs\n",
    "    \"\"\"\n",
    "    dataset = tf.data.TFRecordDataset([file_path])\n",
    "    if training:\n",
    "        dataset = dataset.shuffle(10000)\n",
    "    dataset = dataset.map(parse_example).batch(BATCH_SIZE)\n",
    "\n",
    "    return dataset"
   ]
  },
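  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The neighbor feature names follow a fixed pattern, which `parse_example` builds from the prefix and suffix constants. As a small self-contained sketch (the helper `neighbor_feature_keys` is a hypothetical name, and the constants are repeated as default arguments so the cell runs on its own), these are the keys the parser would look for with two neighbors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def neighbor_feature_keys(num_neighbors, prefix='NL_nbr_',\n",
    "                          weight_suffix='_weight', feature_name='words'):\n",
    "    \"\"\"Lists the (feature key, weight key) pair expected for each neighbor.\"\"\"\n",
    "    return [('{}{}_{}'.format(prefix, i, feature_name),\n",
    "             '{}{}{}'.format(prefix, i, weight_suffix))\n",
    "            for i in range(num_neighbors)]\n",
    "\n",
    "\n",
    "for feature_key, weight_key in neighbor_feature_keys(num_neighbors=2):\n",
    "    print(feature_key, weight_key)\n",
    "# prints:\n",
    "# NL_nbr_0_words NL_nbr_0_weight\n",
    "# NL_nbr_1_words NL_nbr_1_weight"
   ]
  },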
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll instantiate the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)\n",
    "test_dataset = make_dataset(TEST_DATA_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's build the model. It starts as a regular feedforward neural network:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model(dropout_rate):\n",
    "    \"\"\"Creates a sequential multi-layer perceptron model.\"\"\"\n",
    "    return tf.keras.Sequential([\n",
    "        # multi-hot encoded input.\n",
    "        tf.keras.layers.InputLayer(\n",
    "            input_shape=(MAX_SEQ_LENGTH,), name='words'),\n",
    "\n",
    "        # 2 fully-connected layers + dropout\n",
    "        tf.keras.layers.Dense(64, activation='relu'),\n",
    "        tf.keras.layers.Dropout(dropout_rate),\n",
    "        tf.keras.layers.Dense(64, activation='relu'),\n",
    "        tf.keras.layers.Dropout(dropout_rate),\n",
    "\n",
    "        # Softmax output\n",
    "        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')\n",
    "    ])\n",
    "\n",
    "\n",
    "# Build a new base MLP model.\n",
    "model = build_model(dropout_rate=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll wrap the model with the graph regularization procedure, which allows us to include the structured component of the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the base MLP model with graph regularization.\n",
    "graph_reg_config = nsl.configs.make_graph_reg_config(\n",
    "    max_neighbors=NUM_NEIGHBORS,\n",
    "    multiplier=0.1,\n",
    "    distance_type=nsl.configs.DistanceType.L2,\n",
    "    sum_over_axis=-1)\n",
    "graph_reg_model = nsl.keras.GraphRegularization(model,\n",
    "                                                graph_reg_config)"
   ]
  },
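  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, graph regularization adds a penalty to the supervised loss whenever the model's output for a sample differs from its output for the sample's neighbors. The following NumPy sketch illustrates the idea under simplifying assumptions: a squared L2 distance summed over the class axis, one neighbor per sample, and a mean over the batch. The function `composite_loss_sketch` and the toy values are hypothetical; this is not the NSL implementation, whose exact reduction may differ:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def composite_loss_sketch(supervised_loss, sample_outputs, neighbor_outputs,\n",
    "                          neighbor_weights, multiplier=0.1):\n",
    "    \"\"\"Returns (total loss, graph loss) for one batch; a simplified sketch.\"\"\"\n",
    "    # Squared L2 distance between each sample's output and its neighbor's\n",
    "    # output, summed over the class axis.\n",
    "    squared_distance = np.sum((sample_outputs - neighbor_outputs) ** 2, axis=-1)\n",
    "    # A weight of 0.0 marks a missing neighbor, which contributes no penalty.\n",
    "    graph_loss = np.mean(neighbor_weights * squared_distance)\n",
    "    return supervised_loss + multiplier * graph_loss, graph_loss\n",
    "\n",
    "\n",
    "# Toy softmax outputs for two samples and their single neighbors.\n",
    "toy_sample_outputs = np.array([[0.7, 0.3], [0.5, 0.5]])\n",
    "toy_neighbor_outputs = np.array([[0.6, 0.4], [0.1, 0.9]])\n",
    "# The second sample has no real neighbor, so its weight is 0.0.\n",
    "toy_neighbor_weights = np.array([1.0, 0.0])\n",
    "\n",
    "toy_total_loss, toy_graph_loss = composite_loss_sketch(\n",
    "    supervised_loss=1.0,\n",
    "    sample_outputs=toy_sample_outputs,\n",
    "    neighbor_outputs=toy_neighbor_outputs,\n",
    "    neighbor_weights=toy_neighbor_weights)\n",
    "print('graph loss: {:.4f}, total loss: {:.4f}'.format(\n",
    "    toy_graph_loss, toy_total_loss))\n",
    "# prints: graph loss: 0.0100, total loss: 1.0010"
   ]
  },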
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now run the training for 100 epochs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W1208 21:10:47.376572 139984963245888 deprecation.py:323] From /usr/local/lib/python3.7/dist-packages/tensorflow_core/python/ops/array_grad.py:502: _EagerTensorBase.cpu (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.identity instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 1.9226 - accuracy: 0.1972 - graph_loss: 0.0078\n",
      "Epoch 2/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 1.8297 - accuracy: 0.2905 - graph_loss: 0.0126\n",
      "Epoch 3/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 1.7166 - accuracy: 0.3332 - graph_loss: 0.0269\n",
      "Epoch 4/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 1.5948 - accuracy: 0.3810 - graph_loss: 0.0463\n",
      "Epoch 5/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 1.4489 - accuracy: 0.4896 - graph_loss: 0.0745\n",
      "Epoch 6/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 1.2697 - accuracy: 0.5582 - graph_loss: 0.1225\n",
      "Epoch 7/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 1.1070 - accuracy: 0.6422 - graph_loss: 0.1673\n",
      "Epoch 8/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.9690 - accuracy: 0.6807 - graph_loss: 0.2088\n",
      "Epoch 9/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.8550 - accuracy: 0.7383 - graph_loss: 0.2381\n",
      "Epoch 10/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.7695 - accuracy: 0.7749 - graph_loss: 0.2603\n",
      "Epoch 11/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.6728 - accuracy: 0.8037 - graph_loss: 0.2745\n",
      "Epoch 12/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.6142 - accuracy: 0.8311 - graph_loss: 0.2858\n",
      "Epoch 13/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.5465 - accuracy: 0.8469 - graph_loss: 0.2891\n",
      "Epoch 14/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.5009 - accuracy: 0.8659 - graph_loss: 0.2998\n",
      "Epoch 15/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 0.4399 - accuracy: 0.8775 - graph_loss: 0.3071\n",
      "Epoch 16/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.4105 - accuracy: 0.8914 - graph_loss: 0.3093\n",
      "Epoch 17/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.3776 - accuracy: 0.9049 - graph_loss: 0.3141\n",
      "Epoch 18/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 0.3858 - accuracy: 0.9026 - graph_loss: 0.3104\n",
      "Epoch 19/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 0.3224 - accuracy: 0.9220 - graph_loss: 0.3131\n",
      "Epoch 20/100\n",
      "17/17 [==============================] - 1s 35ms/step - loss: 0.3085 - accuracy: 0.9234 - graph_loss: 0.3274\n",
      "Epoch 21/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.2857 - accuracy: 0.9341 - graph_loss: 0.3237\n",
      "Epoch 22/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.2774 - accuracy: 0.9364 - graph_loss: 0.3230\n",
      "Epoch 23/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.2554 - accuracy: 0.9364 - graph_loss: 0.3238\n",
      "Epoch 24/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.2215 - accuracy: 0.9559 - graph_loss: 0.3176\n",
      "Epoch 25/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.2166 - accuracy: 0.9485 - graph_loss: 0.3120\n",
      "Epoch 26/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 0.2086 - accuracy: 0.9545 - graph_loss: 0.3181\n",
      "Epoch 27/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.2023 - accuracy: 0.9578 - graph_loss: 0.3294\n",
      "Epoch 28/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.1965 - accuracy: 0.9508 - graph_loss: 0.3311\n",
      "Epoch 29/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.1950 - accuracy: 0.9559 - graph_loss: 0.3288\n",
      "Epoch 30/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.1699 - accuracy: 0.9657 - graph_loss: 0.3234\n",
      "Epoch 31/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.1584 - accuracy: 0.9703 - graph_loss: 0.3248\n",
      "Epoch 32/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 0.1656 - accuracy: 0.9666 - graph_loss: 0.3260\n",
      "Epoch 33/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.1710 - accuracy: 0.9657 - graph_loss: 0.3356\n",
      "Epoch 34/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.1486 - accuracy: 0.9703 - graph_loss: 0.3258\n",
      "Epoch 35/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.1464 - accuracy: 0.9666 - graph_loss: 0.3257\n",
      "Epoch 36/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.1354 - accuracy: 0.9768 - graph_loss: 0.3470\n",
      "Epoch 37/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.1391 - accuracy: 0.9754 - graph_loss: 0.3340\n",
      "Epoch 38/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.1369 - accuracy: 0.9759 - graph_loss: 0.3286\n",
      "Epoch 39/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.1220 - accuracy: 0.9800 - graph_loss: 0.3274\n",
      "Epoch 40/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.1323 - accuracy: 0.9717 - graph_loss: 0.3343\n",
      "Epoch 41/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.1280 - accuracy: 0.9735 - graph_loss: 0.3384\n",
      "Epoch 42/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.1093 - accuracy: 0.9828 - graph_loss: 0.3266\n",
      "Epoch 43/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.1075 - accuracy: 0.9828 - graph_loss: 0.3233\n",
      "Epoch 44/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.1070 - accuracy: 0.9828 - graph_loss: 0.3333\n",
      "Epoch 45/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0998 - accuracy: 0.9865 - graph_loss: 0.3313\n",
      "Epoch 46/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.1054 - accuracy: 0.9814 - graph_loss: 0.3367\n",
      "Epoch 47/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.1055 - accuracy: 0.9824 - graph_loss: 0.3350\n",
      "Epoch 48/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.1020 - accuracy: 0.9824 - graph_loss: 0.3378\n",
      "Epoch 49/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0997 - accuracy: 0.9861 - graph_loss: 0.3247\n",
      "Epoch 50/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.1020 - accuracy: 0.9838 - graph_loss: 0.3274\n",
      "Epoch 51/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 0.0977 - accuracy: 0.9865 - graph_loss: 0.3264\n",
      "Epoch 52/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0967 - accuracy: 0.9870 - graph_loss: 0.3292\n",
      "Epoch 53/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0996 - accuracy: 0.9879 - graph_loss: 0.3331\n",
      "Epoch 54/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0841 - accuracy: 0.9907 - graph_loss: 0.3201\n",
      "Epoch 55/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 0.0914 - accuracy: 0.9865 - graph_loss: 0.3336\n",
      "Epoch 56/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.0816 - accuracy: 0.9870 - graph_loss: 0.3420\n",
      "Epoch 57/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0887 - accuracy: 0.9852 - graph_loss: 0.3371\n",
      "Epoch 58/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0756 - accuracy: 0.9912 - graph_loss: 0.3285\n",
      "Epoch 59/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0818 - accuracy: 0.9884 - graph_loss: 0.3359\n",
      "Epoch 60/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0864 - accuracy: 0.9852 - graph_loss: 0.3279\n",
      "Epoch 61/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0848 - accuracy: 0.9828 - graph_loss: 0.3221\n",
      "Epoch 62/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0798 - accuracy: 0.9926 - graph_loss: 0.3285\n",
      "Epoch 63/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 0.0789 - accuracy: 0.9898 - graph_loss: 0.3317\n",
      "Epoch 64/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0791 - accuracy: 0.9875 - graph_loss: 0.3362\n",
      "Epoch 65/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0843 - accuracy: 0.9865 - graph_loss: 0.3336\n",
      "Epoch 66/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.0716 - accuracy: 0.9912 - graph_loss: 0.3366\n",
      "Epoch 67/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.0797 - accuracy: 0.9898 - graph_loss: 0.3331\n",
      "Epoch 68/100\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17/17 [==============================] - 1s 33ms/step - loss: 0.0682 - accuracy: 0.9907 - graph_loss: 0.3278\n",
      "Epoch 69/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 0.0753 - accuracy: 0.9879 - graph_loss: 0.3302\n",
      "Epoch 70/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0749 - accuracy: 0.9870 - graph_loss: 0.3237\n",
      "Epoch 71/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0662 - accuracy: 0.9930 - graph_loss: 0.3280\n",
      "Epoch 72/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0738 - accuracy: 0.9893 - graph_loss: 0.3371\n",
      "Epoch 73/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.0795 - accuracy: 0.9879 - graph_loss: 0.3242\n",
      "Epoch 74/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0733 - accuracy: 0.9889 - graph_loss: 0.3344\n",
      "Epoch 75/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0649 - accuracy: 0.9935 - graph_loss: 0.3314\n",
      "Epoch 76/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 0.0709 - accuracy: 0.9912 - graph_loss: 0.3315\n",
      "Epoch 77/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.0694 - accuracy: 0.9907 - graph_loss: 0.3273\n",
      "Epoch 78/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.0737 - accuracy: 0.9889 - graph_loss: 0.3302\n",
      "Epoch 79/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.0658 - accuracy: 0.9916 - graph_loss: 0.3275\n",
      "Epoch 80/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0616 - accuracy: 0.9940 - graph_loss: 0.3345\n",
      "Epoch 81/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.0572 - accuracy: 0.9968 - graph_loss: 0.3197\n",
      "Epoch 82/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0678 - accuracy: 0.9921 - graph_loss: 0.3378\n",
      "Epoch 83/100\n",
      "17/17 [==============================] - 1s 35ms/step - loss: 0.0632 - accuracy: 0.9926 - graph_loss: 0.3351\n",
      "Epoch 84/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0691 - accuracy: 0.9926 - graph_loss: 0.3288\n",
      "Epoch 85/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0616 - accuracy: 0.9935 - graph_loss: 0.3333\n",
      "Epoch 86/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 0.0618 - accuracy: 0.9930 - graph_loss: 0.3332\n",
      "Epoch 87/100\n",
      "17/17 [==============================] - 1s 35ms/step - loss: 0.0667 - accuracy: 0.9912 - graph_loss: 0.3286\n",
      "Epoch 88/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 0.0671 - accuracy: 0.9921 - graph_loss: 0.3377\n",
      "Epoch 89/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0662 - accuracy: 0.9930 - graph_loss: 0.3330\n",
      "Epoch 90/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0634 - accuracy: 0.9912 - graph_loss: 0.3268\n",
      "Epoch 91/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 0.0654 - accuracy: 0.9907 - graph_loss: 0.3253\n",
      "Epoch 92/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0632 - accuracy: 0.9926 - graph_loss: 0.3249\n",
      "Epoch 93/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 0.0557 - accuracy: 0.9963 - graph_loss: 0.3219\n",
      "Epoch 94/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0619 - accuracy: 0.9930 - graph_loss: 0.3279\n",
      "Epoch 95/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 0.0559 - accuracy: 0.9963 - graph_loss: 0.3326\n",
      "Epoch 96/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 0.0597 - accuracy: 0.9935 - graph_loss: 0.3293\n",
      "Epoch 97/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0552 - accuracy: 0.9972 - graph_loss: 0.3273\n",
      "Epoch 98/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0612 - accuracy: 0.9926 - graph_loss: 0.3293\n",
      "Epoch 99/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 0.0593 - accuracy: 0.9926 - graph_loss: 0.3279\n",
      "Epoch 100/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 0.0634 - accuracy: 0.9949 - graph_loss: 0.3284\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f50401e20b8>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph_reg_model.compile(\n",
    "    optimizer='adam',\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    metrics=['accuracy'])\n",
    "\n",
    "# run eagerly to prevent epoch warnings\n",
    "graph_reg_model.run_eagerly = True\n",
    "\n",
    "graph_reg_model.fit(train_dataset, epochs=100, verbose=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can evaluate the results on the test set. As we can see, the model achieves about 81% classification accuracy, which is an increase of around 3% compared to a model without graph regularization. The graph loss is 0 here because the test examples contain no neighbor features:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5/5 [==============================] - 0s 17ms/step - loss: 1.0498 - accuracy: 0.8101 - graph_loss: 0.0000e+00\n",
      "Evaluation accuracy: 0.8101266026496887\n",
      "Evaluation loss: 1.0497615039348602\n",
      "Evaluation graph loss: 0.0\n"
     ]
    }
   ],
   "source": [
    "eval_results = dict(\n",
    "    zip(graph_reg_model.metrics_names,\n",
    "        graph_reg_model.evaluate(test_dataset)))\n",
    "print('Evaluation accuracy: {}'.format(eval_results['accuracy']))\n",
    "print('Evaluation loss: {}'.format(eval_results['loss']))\n",
    "print('Evaluation graph loss: {}'.format(eval_results['graph_loss']))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
